{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp problem_types.seq2seq_text\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test setup\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from m3tl.test_base import TestBase, test_top_layer\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "\n",
    "# Shared fixtures: a TestBase instance and the params object it exposes.\n",
    "test_base = TestBase()\n",
    "params = test_base.params\n",
    "\n",
    "# Hidden size taken from params.bert_config.\n",
    "hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "# Build the train/eval input dataset and pull one batch through the numpy iterator.\n",
    "train_dataset = train_eval_input_fn(params=params)\n",
    "one_batch = next(train_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Seq2seq text generation (seq2seq_text)\n",
    "\n",
    "This module includes the necessary parts to register the Seq2seq text generation problem type.\n",
    "\n",
    "THIS IS NOT IMPLEMENTED YET!!!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "import pickle\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.utils import (LabelEncoder, get_label_encoder_save_path,\n",
    "                        load_transformer_tokenizer, need_make_label_encoder)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class Seq2Seq(tf.keras.Model):\n",
    "    \"\"\"Top layer for the seq2seq text generation problem type (not implemented yet).\n",
    "\n",
    "    Instantiation always raises `NotImplementedError`. The commented-out code in\n",
    "    `__init__` is the draft decoder wiring, kept for reference; it is the only\n",
    "    place that would assign `self.params`, `self.problem_name`, `self.decoder`\n",
    "    and `self.metric_fn`, all of which `call` reads.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.keras.layers.Layer):\n",
    "        \"\"\"Name the model after the problem, then raise `NotImplementedError`.\"\"\"\n",
    "        super().__init__(name=problem_name)\n",
    "        # self.params = params\n",
    "        # self.problem_name = problem_name\n",
    "        # # if self.params.init_weight_from_huggingface:\n",
    "        # #     self.decoder = load_transformer_model(\n",
    "        # #         self.params.transformer_decoder_model_name,\n",
    "        # #         self.params.transformer_decoder_model_loading)\n",
    "        # # else:\n",
    "        # #     self.decoder = load_transformer_model(\n",
    "        # #         self.params.bert_decoder_config, self.params.transformer_decoder_model_loading)\n",
    "\n",
    "        # # TODO: better implementation\n",
    "        # logging.warning(\n",
    "        #     'Seq2Seq model is not well supported yet. Bugs are expected.')\n",
    "        # config = self.params.bert_decoder_config\n",
    "        # # some hacky approach to share embeddings from encoder to decoder\n",
    "        # word_embedding_weight = input_embeddings.word_embeddings\n",
    "        # self.vocab_size = word_embedding_weight.shape[0]\n",
    "        # self.share_embedding_layer = TFSharedEmbeddings(\n",
    "        #     vocab_size=word_embedding_weight.shape[0], hidden_size=word_embedding_weight.shape[1])\n",
    "        # self.share_embedding_layer.build([1])\n",
    "        # self.share_embedding_layer.weight = word_embedding_weight\n",
    "        # # self.decoder = TFBartDecoder(\n",
    "        # #     config=config, embed_tokens=self.share_embedding_layer)\n",
    "        # self.decoder = TFBartDecoderForConditionalGeneration(\n",
    "        #     config=config, embedding_layer=self.share_embedding_layer)\n",
    "        # self.decoder.set_bos_id(self.params.bos_id)\n",
    "        # self.decoder.set_eos_id(self.params.eos_id)\n",
    "\n",
    "        # self.metric_fn = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "        #     name='{}_acc'.format(self.problem_name))\n",
    "        raise NotImplementedError(\n",
    "            'Seq2Seq top layer is not implemented yet.')\n",
    "\n",
    "    def _seq2seq_label_shift_right(self, labels: tf.Tensor, eos_id: int) -> tf.Tensor:\n",
    "        \"\"\"Build decoder targets by dropping the first column and appending `eos_id`.\n",
    "\n",
    "        Args:\n",
    "            labels: 2-D id tensor; it is sliced as `labels[:, 1:]`.\n",
    "            eos_id: id appended as the new last column of every row.\n",
    "\n",
    "        Returns:\n",
    "            The shifted tensor, with the dtype of `labels`.\n",
    "        \"\"\"\n",
    "        batch_size = tf.shape(labels)[0]\n",
    "        # Follow the labels' dtype instead of hard-coding int64: tf.concat needs\n",
    "        # all of its inputs to share one dtype.\n",
    "        eos_column = tf.cast(\n",
    "            tf.fill([batch_size, 1], eos_id), dtype=labels.dtype)\n",
    "        return tf.concat([labels[:, 1:], eos_column], axis=1)\n",
    "\n",
    "    def call(self,\n",
    "             inputs: Tuple[Dict[str, Dict[str, tf.Tensor]], Dict[str, Dict[str, tf.Tensor]]],\n",
    "             mode: str):\n",
    "        \"\"\"Run the decoder on the encoder output.\n",
    "\n",
    "        Args:\n",
    "            inputs: `(features, hidden_features)` pair. Outside PREDICT mode,\n",
    "                `features` is indexed with 'model_input_mask',\n",
    "                '{problem_name}_label_ids' and '{problem_name}_mask';\n",
    "                `hidden_features` is always indexed with 'seq'.\n",
    "            mode: compared against `tf.estimator.ModeKeys.PREDICT`.\n",
    "\n",
    "        Returns:\n",
    "            The result of `self.decoder.generate` in PREDICT mode, otherwise the\n",
    "            decoder logits. Outside PREDICT mode the decoder loss and the value\n",
    "            of `self.metric_fn` are also registered on the model.\n",
    "        \"\"\"\n",
    "        features, hidden_features = inputs\n",
    "        encoder_hidden_states = hidden_features['seq']\n",
    "\n",
    "        if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "            # No label ids or masks are read in PREDICT mode.\n",
    "            return self.decoder.generate(\n",
    "                eos_token_id=self.params.eos_id, encoder_hidden_states=encoder_hidden_states)\n",
    "\n",
    "        label_ids = features['%s_label_ids' % self.problem_name]\n",
    "        decoder_padding_mask = features['{}_mask'.format(self.problem_name)]\n",
    "        decoder_output = self.decoder(input_ids=label_ids,\n",
    "                                      encoder_hidden_states=encoder_hidden_states,\n",
    "                                      encoder_padding_mask=features['model_input_mask'],\n",
    "                                      decoder_padding_mask=decoder_padding_mask,\n",
    "                                      decode_max_length=self.params.decode_max_seq_len,\n",
    "                                      mode=mode)\n",
    "        self.add_loss(decoder_output.loss)\n",
    "        logits = decoder_output.logits\n",
    "\n",
    "        # The metric compares the logits against the shifted label ids.\n",
    "        decoder_label = self._seq2seq_label_shift_right(\n",
    "            label_ids, eos_id=self.params.eos_id)\n",
    "        self.add_metric(self.metric_fn(decoder_label, logits))\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get or make label encoder function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def seq2seq_text_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:\n",
    "    \"\"\"Create and pickle, or load, the decoder tokenizer used as label encoder.\n",
    "\n",
    "    Args:\n",
    "        params: supplies the decoder tokenizer name/loading settings and\n",
    "            receives the problem's `num_classes` when a new encoder is made.\n",
    "        problem: problem name, used to resolve the pickle path.\n",
    "        mode: forwarded to `need_make_label_encoder`.\n",
    "        label_list: not used by this function.\n",
    "        **kwargs: `overwrite` (optional, defaults to False) is forwarded to\n",
    "            `need_make_label_encoder`.\n",
    "\n",
    "    Returns:\n",
    "        The freshly loaded tokenizer, or the object unpickled from the save path.\n",
    "    \"\"\"\n",
    "    le_path = get_label_encoder_save_path(params=params, problem=problem)\n",
    "    # Tolerate callers that do not pass `overwrite`; not overwriting is the\n",
    "    # conservative default.\n",
    "    overwrite = kwargs.get('overwrite', False)\n",
    "\n",
    "    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=overwrite):\n",
    "        # fit and save label encoder\n",
    "        label_encoder = load_transformer_tokenizer(\n",
    "            params.transformer_decoder_tokenizer_name, params.transformer_decoder_tokenizer_loading)\n",
    "        # Context manager so the handle is closed even if pickling fails.\n",
    "        with open(le_path, 'wb') as le_file:\n",
    "            pickle.dump(label_encoder, le_file)\n",
    "        # NOTE(review): `encode_dict` is read from the tokenizer object here;\n",
    "        # confirm that whatever load_transformer_tokenizer returns exposes it.\n",
    "        params.set_problem_info(problem=problem, info_name='num_classes', info=len(label_encoder.encode_dict))\n",
    "    else:\n",
    "        # NOTE: unpickling can execute arbitrary code; only load encoder files\n",
    "        # that were written by a trusted run.\n",
    "        with open(le_path, 'rb') as le_file:\n",
    "            label_encoder = pickle.load(le_file)\n",
    "\n",
    "    return label_encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label handling function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def pad_wrapper(inp, target_len=90):\n",
    "    if len(inp) >= target_len:\n",
    "        return inp[:target_len]\n",
    "    else:\n",
    "        return inp + [0]*(target_len - len(inp))\n",
    "\n",
    "def seq2seq_text_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):\n",
    "    target = [label_encoder.bos_token] + \\\n",
    "        target + [label_encoder.eos_token]\n",
    "    label_dict = label_encoder(\n",
    "        target, add_special_tokens=False, is_split_into_words=True)\n",
    "    label_id = label_dict['input_ids']\n",
    "    label_mask = label_dict['attention_mask']\n",
    "    label_id = pad_wrapper(label_id, decoding_length)\n",
    "    label_mask = pad_wrapper(label_mask, decoding_length)\n",
    "    return label_id, label_mask\n",
    "\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
